/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.util.hll;

import org.apache.solr.util.LongIterator;

/**
 * A vector (array) of bits that is accessed in units ("registers") of <code>width</code> bits which
 * are stored as 64bit "words" (<code>long</code>s). In this context a register is at most 64bits.
 */
class BitVector implements Cloneable {
  // NOTE:  in this context, a word is 64bits

  // rather than doing division to determine how a bit index fits into 64bit
  // words (i.e. longs), bit shifting is used
  private static final int LOG2_BITS_PER_WORD = 6 /*=>64bits*/;
  private static final int BITS_PER_WORD = 1 << LOG2_BITS_PER_WORD;
  private static final int BITS_PER_WORD_MASK = BITS_PER_WORD - 1;

  // ditto from above but for bytes (for output)
  private static final int LOG2_BITS_PER_BYTE = 3 /*=>8bits*/;
  public static final int BITS_PER_BYTE = 1 << LOG2_BITS_PER_BYTE;

  // ========================================================================
  public static final int BYTES_PER_WORD = 8 /*8 bytes in a long*/;

  // ************************************************************************
  // 64bit words
  private final long[] words;

  public final long[] words() {
    return words;
  }

  public final int wordCount() {
    return words.length;
  }

  public final int byteCount() {
    return wordCount() * BYTES_PER_WORD;
  }

  // the width of a register in bits (this cannot be more than 64 (the word size))
  private final int registerWidth;

  public final int registerWidth() {
    return registerWidth;
  }

  private final long count;

  // ------------------------------------------------------------------------
  private final long registerMask;

  // ========================================================================
  /**
   * @param width the width of each register. This cannot be negative or zero or greater than 63
   *     (the signed word size).
   * @param count the number of registers. This cannot be negative or zero
   */
  public BitVector(final int width, final long count) {
    // ceil((width * count)/BITS_PER_WORD)
    this.words = new long[(int) (((width * count) + BITS_PER_WORD_MASK) >>> LOG2_BITS_PER_WORD)];
    this.registerWidth = width;
    this.count = count;

    this.registerMask = (1L << width) - 1;
  }

  // ========================================================================
  /**
   * @param registerIndex the index of the register whose value is to be retrieved. This cannot be
   *     negative.
   * @return the value at the specified register index
   * @see #setRegister(long, long)
   * @see #setMaxRegister(long, long)
   */
  // NOTE:  if this changes then setMaxRegister() must change
  public long getRegister(final long registerIndex) {
    final long bitIndex = registerIndex * registerWidth;
    final int firstWordIndex =
        (int) (bitIndex >>> LOG2_BITS_PER_WORD) /*aka (bitIndex / BITS_PER_WORD)*/;
    final int secondWordIndex =
        (int) ((bitIndex + registerWidth - 1) >>> LOG2_BITS_PER_WORD) /*see above*/;
    final int bitRemainder =
        (int) (bitIndex & BITS_PER_WORD_MASK) /*aka (bitIndex % BITS_PER_WORD)*/;

    if (firstWordIndex == secondWordIndex)
      return ((words[firstWordIndex] >>> bitRemainder) & registerMask);
    /* else -- register spans words */
    return (words[firstWordIndex] >>> bitRemainder) /*no need to mask since at top of word*/
        | ((words[secondWordIndex] << (BITS_PER_WORD - bitRemainder)) & registerMask);
  }

  /**
   * @param registerIndex the index of the register whose value is to be set. This cannot be
   *     negative
   * @param value the value to set in the register
   * @see #getRegister(long)
   * @see #setMaxRegister(long, long)
   */
  // NOTE:  if this changes then setMaxRegister() must change
  public void setRegister(final long registerIndex, final long value) {
    final long bitIndex = registerIndex * registerWidth;
    final int firstWordIndex =
        (int) (bitIndex >>> LOG2_BITS_PER_WORD) /*aka (bitIndex / BITS_PER_WORD)*/;
    final int secondWordIndex =
        (int) ((bitIndex + registerWidth - 1) >>> LOG2_BITS_PER_WORD) /*see above*/;
    final int bitRemainder =
        (int) (bitIndex & BITS_PER_WORD_MASK) /*aka (bitIndex % BITS_PER_WORD)*/;

    final long words[] = this.words /*for convenience/performance*/;
    if (firstWordIndex == secondWordIndex) {
      // clear then set
      words[firstWordIndex] &= ~(registerMask << bitRemainder);
      words[firstWordIndex] |= (value << bitRemainder);
    } else {
      /*register spans words*/
      // clear then set each partial word
      words[firstWordIndex] &= (1L << bitRemainder) - 1;
      words[firstWordIndex] |= (value << bitRemainder);

      words[secondWordIndex] &= ~(registerMask >>> (BITS_PER_WORD - bitRemainder));
      words[secondWordIndex] |= (value >>> (BITS_PER_WORD - bitRemainder));
    }
  }

  // ------------------------------------------------------------------------
  /**
   * @return a <code>LongIterator</code> for iterating starting at the register with index zero.
   *     This will never be <code>null</code>.
   */
  public LongIterator registerIterator() {
    return new LongIterator() {
      final int registerWidth = BitVector.this.registerWidth;
      final long[] words = BitVector.this.words;
      final long registerMask = BitVector.this.registerMask;

      // register setup
      long registerIndex = 0;
      int wordIndex = 0;
      int remainingWordBits = BITS_PER_WORD;
      long word = words[wordIndex];

      @Override
      public long next() {
        long register;
        if (remainingWordBits >= registerWidth) {
          register = word & registerMask;

          // shift to the next register
          word >>>= registerWidth;
          remainingWordBits -= registerWidth;
        } else {
          /*insufficient bits remaining in current word*/
          wordIndex++ /*move to the next word*/;

          register = (word | (words[wordIndex] << remainingWordBits)) & registerMask;

          // shift to the next partial register (word)
          word = words[wordIndex] >>> (registerWidth - remainingWordBits);
          remainingWordBits += BITS_PER_WORD - registerWidth;
        }
        registerIndex++;
        return register;
      }

      @Override
      public boolean hasNext() {
        return registerIndex < count;
      }
    };
  }

  // ------------------------------------------------------------------------
  // composite accessors
  /**
   * Sets the value of the specified index register if and only if the specified value is greater
   * than the current value in the register. This is equivalent to but much more performant than:
   *
   * <pre>vector.setRegister(index, Math.max(vector.getRegister(index), value));</pre>
   *
   * @param registerIndex the index of the register whose value is to be set. This cannot be
   *     negative
   * @param value the value to set in the register if and only if this value is greater than the
   *     current value in the register
   * @return <code>true</code> if and only if the specified value is greater than or equal to the
   *     current register value. <code>false</code> otherwise.
   * @see #getRegister(long)
   * @see #setRegister(long, long)
   * @see java.lang.Math#max(long, long)
   */
  // NOTE:  if this changes then setRegister() must change
  public boolean setMaxRegister(final long registerIndex, final long value) {
    final long bitIndex = registerIndex * registerWidth;
    final int firstWordIndex =
        (int) (bitIndex >>> LOG2_BITS_PER_WORD) /*aka (bitIndex / BITS_PER_WORD)*/;
    final int secondWordIndex =
        (int) ((bitIndex + registerWidth - 1) >>> LOG2_BITS_PER_WORD) /*see above*/;
    final int bitRemainder =
        (int) (bitIndex & BITS_PER_WORD_MASK) /*aka (bitIndex % BITS_PER_WORD)*/;

    // NOTE:  matches getRegister()
    final long registerValue;
    final long words[] = this.words /*for convenience/performance*/;
    if (firstWordIndex == secondWordIndex)
      registerValue = ((words[firstWordIndex] >>> bitRemainder) & registerMask);
    else /*register spans words*/
      registerValue =
          (words[firstWordIndex] >>> bitRemainder) /*no need to mask since at top of word*/
              | ((words[secondWordIndex] << (BITS_PER_WORD - bitRemainder)) & registerMask);

    // determine which is the larger and update as necessary
    if (value > registerValue) {
      // NOTE:  matches setRegister()
      if (firstWordIndex == secondWordIndex) {
        // clear then set
        words[firstWordIndex] &= ~(registerMask << bitRemainder);
        words[firstWordIndex] |= (value << bitRemainder);
      } else {
        /*register spans words*/
        // clear then set each partial word
        words[firstWordIndex] &= (1L << bitRemainder) - 1;
        words[firstWordIndex] |= (value << bitRemainder);

        words[secondWordIndex] &= ~(registerMask >>> (BITS_PER_WORD - bitRemainder));
        words[secondWordIndex] |= (value >>> (BITS_PER_WORD - bitRemainder));
      }
    } /* else -- the register value is greater (or equal) so nothing needs to be done */

    return (value >= registerValue);
  }

  // ========================================================================
  /**
   * Fills this bit vector with the specified bit value. This can be used to clear the vector by
   * specifying <code>0</code>.
   *
   * @param value the value to set all bits to (only the lowest bit is used)
   */
  public void fill(final long value) {
    for (long i = 0; i < count; i++) {
      setRegister(i, value);
    }
  }

  // ------------------------------------------------------------------------
  /**
   * Serializes the registers of the vector using the specified serializer.
   *
   * @param serializer the serializer to use. This cannot be <code>null</code>.
   */
  public void getRegisterContents(final IWordSerializer serializer) {
    for (final LongIterator iter = registerIterator(); iter.hasNext(); ) {
      serializer.writeWord(iter.next());
    }
  }

  /**
   * Creates a deep copy of this vector.
   *
   * @see java.lang.Object#clone()
   */
  @Override
  public BitVector clone() {
    final BitVector copy = new BitVector(registerWidth, count);
    System.arraycopy(words, 0, copy.words, 0, words.length);
    return copy;
  }
}
